!--------------------------------------------------------------------------------------------------!
!   CP2K: A general program to perform molecular dynamics simulations                              !
!   Copyright 2000-2025 CP2K developers group <https://cp2k.org>                                   !
!                                                                                                  !
!   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
!--------------------------------------------------------------------------------------------------!

! **************************************************************************************************
!> \brief Distribution methods for atoms, particles, or molecules
!> \par History
!>      - 1d-distribution of molecules and particles (Sep. 2003, MK)
!>      - 2d-distribution for Quickstep updated with molecules (Oct. 2003, MK)
!> \author MK (22.08.2003)
! **************************************************************************************************
MODULE distribution_methods
   USE atomic_kind_types,               ONLY: atomic_kind_type,&
                                              get_atomic_kind,&
                                              get_atomic_kind_set
   USE basis_set_types,                 ONLY: get_gto_basis_set,&
                                              gto_basis_set_type
   USE cell_types,                      ONLY: cell_type,&
                                              pbc,&
                                              real_to_scaled,&
                                              scaled_to_real
   USE cp_array_utils,                  ONLY: cp_1d_i_p_type
   USE cp_blacs_env,                    ONLY: cp_blacs_env_type
   USE cp_dbcsr_api,                    ONLY: dbcsr_distribution_get_num_images
   USE cp_log_handling,                 ONLY: cp_get_default_logger,&
                                              cp_logger_get_default_io_unit,&
                                              cp_logger_get_default_unit_nr,&
                                              cp_logger_type
   USE cp_min_heap,                     ONLY: cp_heap_fill,&
                                              cp_heap_get_first,&
                                              cp_heap_new,&
                                              cp_heap_release,&
                                              cp_heap_reset_first,&
                                              cp_heap_type
   USE cp_output_handling,              ONLY: cp_p_file,&
                                              cp_print_key_finished_output,&
                                              cp_print_key_should_output,&
                                              cp_print_key_unit_nr
   USE distribution_1d_types,           ONLY: distribution_1d_create,&
                                              distribution_1d_type
   USE distribution_2d_types,           ONLY: distribution_2d_create,&
                                              distribution_2d_type,&
                                              distribution_2d_write
   USE input_constants,                 ONLY: model_block_count,&
                                              model_block_lmax
   USE input_section_types,             ONLY: section_vals_get_subs_vals,&
                                              section_vals_type,&
                                              section_vals_val_get
   USE kinds,                           ONLY: dp,&
                                              int_8
   USE machine,                         ONLY: m_flush
   USE mathconstants,                   ONLY: pi
   USE mathlib,                         ONLY: gcd,&
                                              lcm
   USE molecule_kind_types,             ONLY: get_molecule_kind,&
                                              get_molecule_kind_set,&
                                              molecule_kind_type
   USE molecule_types,                  ONLY: molecule_type
   USE parallel_rng_types,              ONLY: UNIFORM,&
                                              rng_stream_type
   USE particle_types,                  ONLY: particle_type
   USE qs_kind_types,                   ONLY: get_qs_kind,&
                                              qs_kind_type
   USE util,                            ONLY: sort
#include "./base/base_uses.f90"

   IMPLICIT NONE

   PRIVATE

! *** Global parameters (in this module) ***

   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'distribution_methods'

! *** Public subroutines ***

   PUBLIC :: distribute_molecules_1d, &
             distribute_molecules_2d

CONTAINS

! **************************************************************************************************
!> \brief Distribute molecules and particles
!> \param atomic_kind_set particle (atomic) kind information
!> \param particle_set particle information
!> \param local_particles distribution of particles created by this routine
!> \param molecule_kind_set molecule kind information
!> \param molecule_set molecule information
!> \param local_molecules distribution of molecules created by this routine
!> \param force_env_section input section that holds the PRINT%DISTRIBUTION1D print key
!> \param prev_molecule_kind_set previous molecule kind information, used with
!>        prev_local_molecules
!> \param prev_local_molecules previous distribution of molecules, new one will
!>        be identical if all the prev_* arguments are present and associated
!> \par History
!>      none
!> \author MK (Jun. 2003)
! **************************************************************************************************
   SUBROUTINE distribute_molecules_1d(atomic_kind_set, particle_set, &
                                      local_particles, &
                                      molecule_kind_set, molecule_set, &
                                      local_molecules, force_env_section, &
                                      prev_molecule_kind_set, &
                                      prev_local_molecules)

      TYPE(atomic_kind_type), DIMENSION(:), POINTER      :: atomic_kind_set
      TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
      TYPE(distribution_1d_type), POINTER                :: local_particles
      TYPE(molecule_kind_type), DIMENSION(:), POINTER    :: molecule_kind_set
      TYPE(molecule_type), DIMENSION(:), POINTER         :: molecule_set
      TYPE(distribution_1d_type), POINTER                :: local_molecules
      TYPE(section_vals_type), POINTER                   :: force_env_section
      TYPE(molecule_kind_type), DIMENSION(:), OPTIONAL, &
         POINTER                                         :: prev_molecule_kind_set
      TYPE(distribution_1d_type), OPTIONAL, POINTER      :: prev_local_molecules

      CHARACTER(len=*), PARAMETER :: routineN = 'distribute_molecules_1d'

      INTEGER :: atom_a, bin, handle, iatom, imolecule, imolecule_kind, imolecule_local, &
         imolecule_prev_kind, iparticle_kind, ipe, iw, kind_a, molecule_a, n, natom, nbins, nload, &
         nmolecule, nmolecule_kind, nparticle_kind, nsgf, output_unit
      INTEGER(int_8)                                     :: bin_price
      INTEGER(int_8), ALLOCATABLE, DIMENSION(:)          :: workload_count, workload_fill
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: nmolecule_local, nparticle_local, work
      INTEGER, DIMENSION(:), POINTER                     :: molecule_list
      LOGICAL                                            :: found, has_prev_subsys_info, is_local
      TYPE(cp_1d_i_p_type), ALLOCATABLE, DIMENSION(:)    :: local_molecule
      TYPE(cp_heap_type)                                 :: bin_heap_count, bin_heap_fill
      TYPE(cp_logger_type), POINTER                      :: logger
      TYPE(molecule_kind_type), POINTER                  :: molecule_kind

      CALL timeset(routineN, handle)

      ! The previous distribution is only reused when both optional arguments
      ! are present and associated; otherwise a new distribution is generated.
      has_prev_subsys_info = .FALSE.
      IF (PRESENT(prev_local_molecules) .AND. &
          PRESENT(prev_molecule_kind_set)) THEN
         IF (ASSOCIATED(prev_local_molecules) .AND. &
             ASSOCIATED(prev_molecule_kind_set)) THEN
            has_prev_subsys_info = .TRUE.
         END IF
      END IF

      logger => cp_get_default_logger()

      ! mype is the 1-based index of this process, so that it can be compared
      ! directly with the (1-based) bin indices handed out by the heaps.
      ASSOCIATE (group => logger%para_env, mype => logger%para_env%mepos + 1, &
                 npe => logger%para_env%num_pe)

         ! Two identical bookkeeping sets are kept: *_count is used by the pass
         ! that counts the local molecules, *_fill by the pass that stores them.
         ! Both start from zero load and replay the same assignment sequence.
         ALLOCATE (workload_count(npe))
         workload_count(:) = 0

         ALLOCATE (workload_fill(npe))
         workload_fill(:) = 0

         nmolecule_kind = SIZE(molecule_kind_set)

         ALLOCATE (nmolecule_local(nmolecule_kind))
         nmolecule_local(:) = 0

         ALLOCATE (local_molecule(nmolecule_kind))

         nparticle_kind = SIZE(atomic_kind_set)

         ALLOCATE (nparticle_local(nparticle_kind))
         nparticle_local(:) = 0

         ! One bin per process
         nbins = npe

         CALL cp_heap_new(bin_heap_count, nbins)
         CALL cp_heap_fill(bin_heap_count, workload_count)

         CALL cp_heap_new(bin_heap_fill, nbins)
         CALL cp_heap_fill(bin_heap_fill, workload_fill)

         DO imolecule_kind = 1, nmolecule_kind

            molecule_kind => molecule_kind_set(imolecule_kind)

            NULLIFY (molecule_list)

!     *** Get the number of molecules and the number of ***
!     *** atoms in each molecule of that molecular kind ***

            CALL get_molecule_kind(molecule_kind=molecule_kind, &
                                   molecule_list=molecule_list, &
                                   natom=natom, &
                                   nsgf=nsgf)

!     *** Consider the number of atoms or basis ***
!     *** functions which depends on the method ***

            nload = MAX(natom, nsgf)
            nmolecule = SIZE(molecule_list)

!     *** Get the number of local molecules of the current molecule kind ***

            DO imolecule = 1, nmolecule
               IF (has_prev_subsys_info) THEN
                  ! A molecule is local if it is found in any of the previous local
                  ! lists. Leave the search at the first hit: this counts every
                  ! molecule at most once, exactly as the fill pass below stores it.
                  DO imolecule_prev_kind = 1, SIZE(prev_molecule_kind_set)
                     IF (ANY(prev_local_molecules%list(imolecule_prev_kind)%array( &
                             1:prev_local_molecules%n_el(imolecule_prev_kind)) == molecule_list(imolecule))) THEN
                        ! molecule used to be local
                        nmolecule_local(imolecule_kind) = nmolecule_local(imolecule_kind) + 1
                        EXIT
                     END IF
                  END DO
               ELSE
                  ! Take the bin on top of the heap, charge the load of this molecule
                  ! to it and push the updated load back into the heap.
                  CALL cp_heap_get_first(bin_heap_count, bin, bin_price, found)
                  IF (.NOT. found) &
                     CPABORT("No topmost heap element found.")

                  ipe = bin
                  IF (bin_price /= workload_count(ipe)) &
                     CPABORT("inconsistent heap")

                  workload_count(ipe) = workload_count(ipe) + nload
                  IF (ipe == mype) THEN
                     nmolecule_local(imolecule_kind) = nmolecule_local(imolecule_kind) + 1
                  END IF

                  bin_price = workload_count(ipe)
                  CALL cp_heap_reset_first(bin_heap_count, bin_price)
               END IF
            END DO

!     *** Distribute the molecules ***
            n = nmolecule_local(imolecule_kind)

            IF (n > 0) THEN
               ALLOCATE (local_molecule(imolecule_kind)%array(n))
            ELSE
               NULLIFY (local_molecule(imolecule_kind)%array)
            END IF

            ! Second pass: repeat the same decisions as the counting pass and
            ! store the global indices of the molecules that are local.
            imolecule_local = 0
            DO imolecule = 1, nmolecule
               is_local = .FALSE.
               IF (has_prev_subsys_info) THEN
                  DO imolecule_prev_kind = 1, SIZE(prev_molecule_kind_set)
                     IF (ANY(prev_local_molecules%list(imolecule_prev_kind)%array( &
                             1:prev_local_molecules%n_el(imolecule_prev_kind)) == molecule_list(imolecule))) THEN
                        is_local = .TRUE.
                        EXIT
                     END IF
                  END DO
               ELSE
                  CALL cp_heap_get_first(bin_heap_fill, bin, bin_price, found)
                  IF (.NOT. found) &
                     CPABORT("No topmost heap element found.")

                  ipe = bin
                  IF (bin_price /= workload_fill(ipe)) &
                     CPABORT("inconsistent heap")

                  workload_fill(ipe) = workload_fill(ipe) + nload
                  is_local = (ipe == mype)
               END IF
               IF (is_local) THEN
                  imolecule_local = imolecule_local + 1
                  molecule_a = molecule_list(imolecule)
                  local_molecule(imolecule_kind)%array(imolecule_local) = molecule_a
                  ! Count the atoms of this molecule per atomic kind
                  DO iatom = 1, natom
                     atom_a = molecule_set(molecule_a)%first_atom + iatom - 1

                     CALL get_atomic_kind(atomic_kind=particle_set(atom_a)%atomic_kind, &
                                          kind_number=kind_a)
                     nparticle_local(kind_a) = nparticle_local(kind_a) + 1
                  END DO
               END IF
               ! The heaps are only used (and ipe only defined) without previous info
               IF (.NOT. has_prev_subsys_info) THEN
                  bin_price = workload_fill(ipe)
                  CALL cp_heap_reset_first(bin_heap_fill, bin_price)
               END IF
            END DO

         END DO

         ! Both passes must have produced the same load per process
         IF (ANY(workload_fill /= workload_count)) &
            CPABORT("Inconsistent heaps encountered")

         CALL cp_heap_release(bin_heap_count)
         CALL cp_heap_release(bin_heap_fill)

!   *** Create the local molecule structure ***

         CALL distribution_1d_create(local_molecules, &
                                     n_el=nmolecule_local, &
                                     para_env=logger%para_env)

!   *** Create the local particle structure ***

         CALL distribution_1d_create(local_particles, &
                                     n_el=nparticle_local, &
                                     para_env=logger%para_env)

!   *** Store the generated local molecule and particle distributions ***

         ! nparticle_local is reused as the running insertion index per atomic
         ! kind; after the loop it holds the per-kind totals again.
         nparticle_local(:) = 0

         DO imolecule_kind = 1, nmolecule_kind

            IF (nmolecule_local(imolecule_kind) == 0) CYCLE

            local_molecules%list(imolecule_kind)%array(:) = &
               local_molecule(imolecule_kind)%array(:)

            molecule_kind => molecule_kind_set(imolecule_kind)

            CALL get_molecule_kind(molecule_kind=molecule_kind, &
                                   natom=natom)

            DO imolecule = 1, nmolecule_local(imolecule_kind)
               molecule_a = local_molecule(imolecule_kind)%array(imolecule)
               DO iatom = 1, natom
                  atom_a = molecule_set(molecule_a)%first_atom + iatom - 1
                  CALL get_atomic_kind(atomic_kind=particle_set(atom_a)%atomic_kind, &
                                       kind_number=kind_a)
                  nparticle_local(kind_a) = nparticle_local(kind_a) + 1
                  local_particles%list(kind_a)%array(nparticle_local(kind_a)) = atom_a
               END DO
            END DO

         END DO

!   *** Print distribution, if requested ***

         IF (BTEST(cp_print_key_should_output(logger%iter_info, &
                                              force_env_section, "PRINT%DISTRIBUTION1D"), cp_p_file)) THEN

            output_unit = cp_print_key_unit_nr(logger, force_env_section, "PRINT%DISTRIBUTION1D", &
                                               extension=".Log")

            ! Summary tables go to output_unit (only where it is positive); the
            ! per-process lists go to iw, which falls back to the logger's local
            ! default unit when no print-key unit is available on this process.
            iw = output_unit
            IF (output_unit < 0) iw = cp_logger_get_default_unit_nr(logger, LOCAL=.TRUE.)

!     *** Print molecule distribution ***

            ALLOCATE (work(npe))
            work(:) = 0

            ! Every process fills only its own slot before the sum over the group
            work(mype) = SUM(nmolecule_local)
            CALL group%sum(work)

            IF (output_unit > 0) THEN
               WRITE (UNIT=output_unit, &
                      FMT="(/, T2, A, T51, A, /, (T52, I6, T73, I8))") &
                  "DISTRIBUTION OF THE MOLECULES", &
                  "Process    Number of molecules", &
                  (ipe - 1, work(ipe), ipe=1, npe)
               WRITE (UNIT=output_unit, FMT="(T55, A3, T73, I8)") &
                  "Sum", SUM(work)
               CALL m_flush(output_unit)
            END IF

            CALL group%sync()

            ! The processes write one after the other, a flush and a group sync
            ! separate the turns.
            DO ipe = 1, npe
               IF (ipe == mype) THEN
                  WRITE (UNIT=iw, FMT="(/, T3, A)") &
                     "Process   Kind   Local molecules (global indices)"
                  DO imolecule_kind = 1, nmolecule_kind
                     IF (imolecule_kind == 1) THEN
                        WRITE (UNIT=iw, FMT="(T4, I6, 2X, I5, (T21, 10I6))") &
                           ipe - 1, imolecule_kind, &
                           (local_molecules%list(imolecule_kind)%array(imolecule), &
                            imolecule=1, nmolecule_local(imolecule_kind))
                     ELSE
                        WRITE (UNIT=iw, FMT="(T12, I5, (T21, 10I6))") &
                           imolecule_kind, &
                           (local_molecules%list(imolecule_kind)%array(imolecule), &
                            imolecule=1, nmolecule_local(imolecule_kind))
                     END IF
                  END DO
               END IF
               CALL m_flush(iw)
               CALL group%sync()
            END DO

!     *** Print particle distribution ***

            work(:) = 0

            work(mype) = SUM(nparticle_local)
            CALL group%sum(work)

            IF (output_unit > 0) THEN
               WRITE (UNIT=output_unit, &
                      FMT="(/, T2, A, T51, A, /, (T52, I6, T73, I8))") &
                  "DISTRIBUTION OF THE PARTICLES", &
                  "Process    Number of particles", &
                  (ipe - 1, work(ipe), ipe=1, npe)
               WRITE (UNIT=output_unit, FMT="(T55, A3, T73, I8)") &
                  "Sum", SUM(work)
               CALL m_flush(output_unit)
            END IF

            CALL group%sync()

            DO ipe = 1, npe
               IF (ipe == mype) THEN
                  WRITE (UNIT=iw, FMT="(/, T3, A)") &
                     "Process   Kind   Local particles (global indices)"
                  DO iparticle_kind = 1, nparticle_kind
                     IF (iparticle_kind == 1) THEN
                        WRITE (UNIT=iw, FMT="(T4, I6, 2X, I5, (T20, 10I6))") &
                           ipe - 1, iparticle_kind, &
                           (local_particles%list(iparticle_kind)%array(iatom), &
                            iatom=1, nparticle_local(iparticle_kind))
                     ELSE
                        WRITE (UNIT=iw, FMT="(T12, I5, (T20, 10I6))") &
                           iparticle_kind, &
                           (local_particles%list(iparticle_kind)%array(iatom), &
                            iatom=1, nparticle_local(iparticle_kind))
                     END IF
                  END DO
               END IF
               CALL m_flush(iw)
               CALL group%sync()
            END DO
            DEALLOCATE (work)

            CALL cp_print_key_finished_output(output_unit, logger, force_env_section, &
                                              "PRINT%DISTRIBUTION1D")
         END IF
      END ASSOCIATE
!   *** Release work storage ***

      DEALLOCATE (workload_count)

      DEALLOCATE (workload_fill)

      DEALLOCATE (nmolecule_local)

      DEALLOCATE (nparticle_local)

      ! Entries of kinds without local molecules were nullified above
      DO imolecule_kind = 1, nmolecule_kind
         IF (ASSOCIATED(local_molecule(imolecule_kind)%array)) THEN
            DEALLOCATE (local_molecule(imolecule_kind)%array)
         END IF
      END DO
      DEALLOCATE (local_molecule)

      CALL timestop(handle)

   END SUBROUTINE distribute_molecules_1d

! **************************************************************************************************
!> \brief Distributes the particle pairs creating a 2d distribution optimally
!>      suited for quickstep
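!> \param cell simulation cell, used to fold (pbc) and scale the particle coordinates
!> \param atomic_kind_set atomic kind information (number of kinds and of atoms)
!> \param particle_set particle information (positions and atomic kinds)
!> \param qs_kind_set Quickstep kind information (basis set data entering the cost model)
!> \param molecule_kind_set molecule kind information
!> \param molecule_set molecule information (first atom of each molecule)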
!> \param distribution_2d the distribution that will be created by this
!>                         method
!> \param blacs_env the parallel environment at the basis of the
!>                   distribution
!> \param force_env_section input section from which DFT%QS%DISTRIBUTION is read
!> \par History
!>      - local_rows & cols blocksize optimizations (Aug. 2003, MK)
!>      - cleanup of distribution_2d (Sep. 2003, fawzi)
!>      - update for molecules (Oct. 2003, MK)
!> \author fawzi (Feb. 2003)
!> \note
!>      Intermediate generation of a 2d distribution of the molecules, but
!>      only the corresponding particle (atomic) distribution is currently
!>      used. The 2d distribution of the molecules is deleted, but may easily
!>      be recovered (MK).
! **************************************************************************************************
   SUBROUTINE distribute_molecules_2d(cell, atomic_kind_set, particle_set, &
                                      qs_kind_set, molecule_kind_set, molecule_set, &
                                      distribution_2d, blacs_env, force_env_section)
      TYPE(cell_type), POINTER                           :: cell
      TYPE(atomic_kind_type), DIMENSION(:), POINTER      :: atomic_kind_set
      TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
      TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
      TYPE(molecule_kind_type), DIMENSION(:), POINTER    :: molecule_kind_set
      TYPE(molecule_type), DIMENSION(:), POINTER         :: molecule_set
      TYPE(distribution_2d_type), POINTER                :: distribution_2d
      TYPE(cp_blacs_env_type), POINTER                   :: blacs_env
      TYPE(section_vals_type), POINTER                   :: force_env_section

      CHARACTER(len=*), PARAMETER :: routineN = 'distribute_molecules_2d'

      INTEGER :: cluster_price, cost_model, handle, iatom, iatom_mol, iatom_one, ikind, imol, &
         imolecule, imolecule_kind, iparticle_kind, ipcol, iprow, iw, kind_a, n, natom, natom_mol, &
         nclusters, nmolecule, nmolecule_kind, nparticle_kind, nsgf, output_unit
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: cluster_list, cluster_prices, &
                                                            nparticle_local_col, &
                                                            nparticle_local_row, work
      INTEGER, DIMENSION(:), POINTER                     :: lmax_basis, molecule_list
      INTEGER, DIMENSION(:, :), POINTER                  :: cluster_col_distribution, &
                                                            cluster_row_distribution, &
                                                            col_distribution, row_distribution
      LOGICAL :: basic_cluster_optimization, basic_optimization, basic_spatial_optimization, &
         molecular_distribution, skip_optimization
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: coords, pbc_scaled_coords
      REAL(KIND=dp), DIMENSION(3)                        :: center
      TYPE(cp_1d_i_p_type), DIMENSION(:), POINTER        :: local_particle_col, local_particle_row
      TYPE(cp_logger_type), POINTER                      :: logger
      TYPE(gto_basis_set_type), POINTER                  :: orb_basis_set
      TYPE(molecule_kind_type), POINTER                  :: molecule_kind
      TYPE(section_vals_type), POINTER                   :: distribution_section

!...

      CALL timeset(routineN, handle)

      logger => cp_get_default_logger()

      distribution_section => section_vals_get_subs_vals(force_env_section, "DFT%QS%DISTRIBUTION")

      CALL section_vals_val_get(distribution_section, "2D_MOLECULAR_DISTRIBUTION", l_val=molecular_distribution)
      CALL section_vals_val_get(distribution_section, "SKIP_OPTIMIZATION", l_val=skip_optimization)
      CALL section_vals_val_get(distribution_section, "BASIC_OPTIMIZATION", l_val=basic_optimization)
      CALL section_vals_val_get(distribution_section, "BASIC_SPATIAL_OPTIMIZATION", l_val=basic_spatial_optimization)
      CALL section_vals_val_get(distribution_section, "BASIC_CLUSTER_OPTIMIZATION", l_val=basic_cluster_optimization)

      CALL section_vals_val_get(distribution_section, "COST_MODEL", i_val=cost_model)
      !

      ASSOCIATE (group => blacs_env%para_env, myprow => blacs_env%mepos(1) + 1, mypcol => blacs_env%mepos(2) + 1, &
                 nprow => blacs_env%num_pe(1), npcol => blacs_env%num_pe(2))

         nmolecule_kind = SIZE(molecule_kind_set)
         CALL get_molecule_kind_set(molecule_kind_set, nmolecule=nmolecule)

         nparticle_kind = SIZE(atomic_kind_set)
         CALL get_atomic_kind_set(atomic_kind_set=atomic_kind_set, natom=natom)

         !
         ! we need to generate two representations of the distribution, one as a straight array with global particles
         ! one ordered wrt to kinds and only listing the local particles
         !
         ALLOCATE (row_distribution(natom, 2))
         ALLOCATE (col_distribution(natom, 2))
         ! Initialize the distributions to -1, as the second dimension only gets set with cluster optimization
         ! but the information is needed by dbcsr
         row_distribution = -1; col_distribution = -1

         ALLOCATE (local_particle_col(nparticle_kind))
         ALLOCATE (local_particle_row(nparticle_kind))
         ALLOCATE (nparticle_local_row(nparticle_kind))
         ALLOCATE (nparticle_local_col(nparticle_kind))

         IF (basic_optimization .OR. basic_spatial_optimization .OR. basic_cluster_optimization) THEN

            IF (molecular_distribution) THEN
               nclusters = nmolecule
            ELSE
               nclusters = natom
            END IF

            ALLOCATE (cluster_list(nclusters))
            ALLOCATE (cluster_prices(nclusters))
            ALLOCATE (cluster_row_distribution(nclusters, 2))
            ALLOCATE (cluster_col_distribution(nclusters, 2))
            cluster_row_distribution = -1; cluster_col_distribution = -1

            ! Fill in the clusters and their prices
            CALL section_vals_val_get(distribution_section, "COST_MODEL", i_val=cost_model)
            IF (.NOT. molecular_distribution) THEN
               DO iatom = 1, natom
                  IF (iatom > nclusters) &
                     CPABORT("Bounds error")
                  CALL get_atomic_kind(particle_set(iatom)%atomic_kind, kind_number=ikind)
                  cluster_list(iatom) = iatom
                  SELECT CASE (cost_model)
                  CASE (model_block_count)
                     CALL get_qs_kind(qs_kind_set(ikind), nsgf=nsgf)
                     cluster_price = nsgf
                  CASE (model_block_lmax)
                     CALL get_qs_kind(qs_kind_set(ikind), basis_set=orb_basis_set)
                     CALL get_gto_basis_set(orb_basis_set, lmax=lmax_basis)
                     cluster_price = MAXVAL(lmax_basis)
                  CASE default
                     CALL get_qs_kind(qs_kind_set(ikind), basis_set=orb_basis_set)
                     CALL get_gto_basis_set(orb_basis_set, lmax=lmax_basis)
                     cluster_price = 8 + (MAXVAL(lmax_basis)**2)
                  END SELECT
                  cluster_prices(iatom) = cluster_price
               END DO
            ELSE
               imol = 0
               DO imolecule_kind = 1, nmolecule_kind
                  molecule_kind => molecule_kind_set(imolecule_kind)
                  CALL get_molecule_kind(molecule_kind=molecule_kind, molecule_list=molecule_list, natom=natom_mol)
                  DO imolecule = 1, SIZE(molecule_list)
                     imol = imol + 1
                     cluster_list(imol) = imol
                     cluster_price = 0
                     DO iatom_mol = 1, natom_mol
                        iatom = molecule_set(molecule_list(imolecule))%first_atom + iatom_mol - 1
                        CALL get_atomic_kind(particle_set(iatom)%atomic_kind, kind_number=ikind)
                        SELECT CASE (cost_model)
                        CASE (model_block_count)
                           CALL get_qs_kind(qs_kind_set(ikind), nsgf=nsgf)
                           cluster_price = cluster_price + nsgf
                        CASE (model_block_lmax)
                           CALL get_qs_kind(qs_kind_set(ikind), basis_set=orb_basis_set)
                           CALL get_gto_basis_set(orb_basis_set, lmax=lmax_basis)
                           cluster_price = cluster_price + MAXVAL(lmax_basis)
                        CASE default
                           CALL get_qs_kind(qs_kind_set(ikind), basis_set=orb_basis_set)
                           CALL get_gto_basis_set(orb_basis_set, lmax=lmax_basis)
                           cluster_price = cluster_price + 8 + (MAXVAL(lmax_basis)**2)
                        END SELECT
                     END DO
                     cluster_prices(imol) = cluster_price
                  END DO
               END DO
            END IF

            ! And distribute
            IF (basic_optimization) THEN
               CALL make_basic_distribution(cluster_list, cluster_prices, &
                                            nprow, cluster_row_distribution(:, 1), npcol, cluster_col_distribution(:, 1))
            ELSE
               IF (basic_cluster_optimization) THEN
                  IF (molecular_distribution) &
                     CPABORT("clustering and molecular blocking NYI")
                  ALLOCATE (pbc_scaled_coords(3, natom), coords(3, natom))
                  DO iatom = 1, natom
                     CALL real_to_scaled(pbc_scaled_coords(:, iatom), pbc(particle_set(iatom)%r(:), cell), cell)
                     coords(:, iatom) = pbc(particle_set(iatom)%r(:), cell)
                  END DO
                  CALL make_cluster_distribution(coords, pbc_scaled_coords, cell, cluster_prices, &
                                                 nprow, cluster_row_distribution, npcol, cluster_col_distribution)
               ELSE ! basic_spatial_optimization
                  ALLOCATE (pbc_scaled_coords(3, nclusters))
                  IF (.NOT. molecular_distribution) THEN
                     ! just scaled coords
                     DO iatom = 1, natom
                        CALL real_to_scaled(pbc_scaled_coords(:, iatom), pbc(particle_set(iatom)%r(:), cell), cell)
                     END DO
                  ELSE
                     ! use scaled coords of geometric center, folding when appropriate
                     imol = 0
                     DO imolecule_kind = 1, nmolecule_kind
                        molecule_kind => molecule_kind_set(imolecule_kind)
                        CALL get_molecule_kind(molecule_kind=molecule_kind, molecule_list=molecule_list, natom=natom_mol)
                        DO imolecule = 1, SIZE(molecule_list)
                           imol = imol + 1
                           iatom_one = molecule_set(molecule_list(imolecule))%first_atom
                           center = 0.0_dp
                           DO iatom_mol = 1, natom_mol
                              iatom = molecule_set(molecule_list(imolecule))%first_atom + iatom_mol - 1
                              center = center + &
                                   pbc(particle_set(iatom)%r(:) - particle_set(iatom_one)%r(:), cell) + particle_set(iatom_one)%r(:)
                           END DO
                           center = center/natom_mol
                           CALL real_to_scaled(pbc_scaled_coords(:, imol), pbc(center, cell), cell)
                        END DO
                     END DO
                  END IF

                  CALL make_basic_spatial_distribution(pbc_scaled_coords, cluster_prices, &
                                                       nprow, cluster_row_distribution(:, 1), npcol, cluster_col_distribution(:, 1))

                  DEALLOCATE (pbc_scaled_coords)
               END IF
            END IF

            ! And assign back
            IF (.NOT. molecular_distribution) THEN
               row_distribution = cluster_row_distribution
               col_distribution = cluster_col_distribution
            ELSE
               imol = 0
               DO imolecule_kind = 1, nmolecule_kind
                  molecule_kind => molecule_kind_set(imolecule_kind)
                  CALL get_molecule_kind(molecule_kind=molecule_kind, molecule_list=molecule_list, natom=natom_mol)
                  DO imolecule = 1, SIZE(molecule_list)
                     imol = imol + 1
                     DO iatom_mol = 1, natom_mol
                        iatom = molecule_set(molecule_list(imolecule))%first_atom + iatom_mol - 1
                        row_distribution(iatom, :) = cluster_row_distribution(imol, :)
                        col_distribution(iatom, :) = cluster_col_distribution(imol, :)
                     END DO
                  END DO
               END DO
            END IF

            ! cleanup
            DEALLOCATE (cluster_list)
            DEALLOCATE (cluster_prices)
            DEALLOCATE (cluster_row_distribution)
            DEALLOCATE (cluster_col_distribution)

         ELSE
            ! expects nothing else
            CPABORT("")
         END IF

         ! prepare the lists of local particles

         ! count local particles of a given kind
         nparticle_local_col = 0
         nparticle_local_row = 0
         DO iatom = 1, natom
            CALL get_atomic_kind(atomic_kind=particle_set(iatom)%atomic_kind, kind_number=kind_a)
            IF (row_distribution(iatom, 1) == myprow) nparticle_local_row(kind_a) = nparticle_local_row(kind_a) + 1
            IF (col_distribution(iatom, 1) == mypcol) nparticle_local_col(kind_a) = nparticle_local_col(kind_a) + 1
         END DO

         ! allocate space
         DO iparticle_kind = 1, nparticle_kind
            n = nparticle_local_row(iparticle_kind)
            ALLOCATE (local_particle_row(iparticle_kind)%array(n))

            n = nparticle_local_col(iparticle_kind)
            ALLOCATE (local_particle_col(iparticle_kind)%array(n))
         END DO

         ! store
         nparticle_local_col = 0
         nparticle_local_row = 0
         DO iatom = 1, natom
            CALL get_atomic_kind(atomic_kind=particle_set(iatom)%atomic_kind, kind_number=kind_a)
            IF (row_distribution(iatom, 1) == myprow) THEN
               nparticle_local_row(kind_a) = nparticle_local_row(kind_a) + 1
               local_particle_row(kind_a)%array(nparticle_local_row(kind_a)) = iatom
            END IF
            IF (col_distribution(iatom, 1) == mypcol) THEN
               nparticle_local_col(kind_a) = nparticle_local_col(kind_a) + 1
               local_particle_col(kind_a)%array(nparticle_local_col(kind_a)) = iatom
            END IF
         END DO

!   *** Generate the 2d distribution structure  but take care of the zero offsets required
         row_distribution(:, 1) = row_distribution(:, 1) - 1
         col_distribution(:, 1) = col_distribution(:, 1) - 1
         CALL distribution_2d_create(distribution_2d, &
                                     row_distribution_ptr=row_distribution, &
                                     col_distribution_ptr=col_distribution, &
                                     local_rows_ptr=local_particle_row, &
                                     local_cols_ptr=local_particle_col, &
                                     blacs_env=blacs_env)

         NULLIFY (local_particle_row)
         NULLIFY (local_particle_col)
         NULLIFY (row_distribution)
         NULLIFY (col_distribution)

!   *** Print distribution, if requested ***
         IF (BTEST(cp_print_key_should_output(logger%iter_info, &
                                              force_env_section, "PRINT%DISTRIBUTION"), cp_p_file)) THEN

            output_unit = cp_print_key_unit_nr(logger, force_env_section, "PRINT%DISTRIBUTION", &
                                               extension=".Log")

!     *** Print row distribution ***

            ALLOCATE (work(nprow))
            work(:) = 0

            IF (mypcol == 1) work(myprow) = SUM(distribution_2d%n_local_rows)

            CALL group%sum(work)

            IF (output_unit > 0) THEN
               WRITE (UNIT=output_unit, &
                      FMT="(/, T2, A, /, T15, A, /, (T16, I10, T41, I10, T71, I10))") &
                  "DISTRIBUTION OF THE PARTICLES (ROWS)", &
                  "Process row      Number of particles         Number of matrix rows", &
                  (iprow - 1, work(iprow), -1, iprow=1, nprow)
               WRITE (UNIT=output_unit, FMT="(T23, A3, T41, I10, T71, I10)") &
                  "Sum", SUM(work), -1
               CALL m_flush(output_unit)
            END IF

            DEALLOCATE (work)

!     *** Print column distribution ***

            ALLOCATE (work(npcol))
            work(:) = 0

            IF (myprow == 1) work(mypcol) = SUM(distribution_2d%n_local_cols)

            CALL group%sum(work)

            IF (output_unit > 0) THEN
               WRITE (UNIT=output_unit, &
                      FMT="(/, T2, A, /, T15, A, /, (T16, I10, T41, I10, T71, I10))") &
                  "DISTRIBUTION OF THE PARTICLES (COLUMNS)", &
                  "Process col      Number of particles      Number of matrix columns", &
                  (ipcol - 1, work(ipcol), -1, ipcol=1, npcol)
               WRITE (UNIT=output_unit, FMT="(T23, A3, T41, I10, T71, I10)") &
                  "Sum", SUM(work), -1
               CALL m_flush(output_unit)
            END IF

            DEALLOCATE (work)

            CALL cp_print_key_finished_output(output_unit, logger, force_env_section, &
                                              "PRINT%DISTRIBUTION")
         END IF
      END ASSOCIATE

      IF (BTEST(cp_print_key_should_output(logger%iter_info, &
                                           force_env_section, "PRINT%DISTRIBUTION2D"), cp_p_file)) THEN

         iw = cp_logger_get_default_unit_nr(logger, LOCAL=.TRUE.)
         CALL distribution_2d_write(distribution_2d, &
                                    unit_nr=iw, &
                                    local=.TRUE., &
                                    long_description=.TRUE.)

      END IF

!   *** Release work storage ***

      DEALLOCATE (nparticle_local_row)

      DEALLOCATE (nparticle_local_col)

      CALL timestop(handle)

   END SUBROUTINE distribute_molecules_2d

! **************************************************************************************************
!> \brief Creates a basic distribution
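!> \param cluster_list cluster indices; overwritten with the index array returned by sorting cluster_prices
!> \param cluster_prices cost of each cluster; sorted in place
!> \param nprows number of process rows of the 2d grid
!> \param row_distribution process row (1-based) assigned to each cluster
!> \param npcols number of process columns of the 2d grid
!> \param col_distribution process column (1-based) assigned to each cluster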
!> \par History
!> - Created 2010-08-06 UB
! **************************************************************************************************
   SUBROUTINE make_basic_distribution(cluster_list, cluster_prices, &
                                      nprows, row_distribution, npcols, col_distribution)
      INTEGER, DIMENSION(:), INTENT(INOUT)               :: cluster_list, cluster_prices
      INTEGER, INTENT(IN)                                :: nprows
      INTEGER, DIMENSION(:), INTENT(OUT)                 :: row_distribution
      INTEGER, INTENT(IN)                                :: npcols
      INTEGER, DIMENSION(:), INTENT(OUT)                 :: col_distribution

      CHARACTER(len=*), PARAMETER :: routineN = 'make_basic_distribution'

      INTEGER                                            :: common_div, handle, ibin, icluster, &
                                                            irank, n_bins, n_clusters, p_col, p_row
      INTEGER(int_8)                                     :: bin_load
      LOGICAL                                            :: have_bin
      TYPE(cp_heap_type)                                 :: load_heap

      CALL timeset(routineN, handle)

      n_clusters = SIZE(cluster_list)
      common_div = gcd(nprows, npcols)
      n_bins = lcm(nprows, npcols)

      ! Sort the prices; cluster_list receives the index array produced by the sort,
      ! so cluster_list(irank) names the cluster whose price is cluster_prices(irank).
      CALL sort(cluster_prices, n_clusters, cluster_list)

      ! One heap entry per bin, every bin starting with a load of zero.
      CALL cp_heap_new(load_heap, n_bins)
      CALL cp_heap_fill(load_heap, [(0_int_8, ibin=1, n_bins)])

      ! Walk the sorted list from its end (most expensive cluster first) and always
      ! drop the cluster into the bin that currently sits on top of the heap.
      DO irank = n_clusters, 1, -1
         CALL cp_heap_get_first(load_heap, ibin, bin_load, have_bin)
         IF (.NOT. have_bin) THEN
            CPABORT("No topmost heap element found.")
         END IF

         ! Translate the bin number into zero-based process grid coordinates.
         p_row = (ibin - 1)*common_div/npcols
         IF (p_row >= nprows) THEN
            CPABORT("Invalid process row.")
         END IF
         p_col = (ibin - 1)*common_div/nprows
         IF (p_col >= npcols) THEN
            CPABORT("Invalid process column.")
         END IF

         ! The distributions are stored one-based.
         icluster = cluster_list(irank)
         row_distribution(icluster) = p_row + 1
         col_distribution(icluster) = p_col + 1

         ! Charge the cluster to the bin and let the heap reorder itself.
         bin_load = bin_load + cluster_prices(irank)
         CALL cp_heap_reset_first(load_heap, bin_load)
      END DO

      CALL cp_heap_release(load_heap)
      CALL timestop(handle)
   END SUBROUTINE make_basic_distribution

! **************************************************************************************************
!> \brief Creates a basic spatial distribution
!>        that tries to make the corresponding blocks as homogeneous as possible
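!> \param pbc_scaled_coords scaled coordinates, one column per atom (or cluster)
!> \param costs cost of each atom (or cluster)
!> \param nprows number of process rows of the 2d grid
!> \param row_distribution process row (1-based) assigned to each atom (or cluster)
!> \param npcols number of process columns of the 2d grid
!> \param col_distribution process column (1-based) assigned to each atom (or cluster)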
!> \par History
!> - Created 2010-11-11 Joost VandeVondele
! **************************************************************************************************
   SUBROUTINE make_basic_spatial_distribution(pbc_scaled_coords, costs, &
                                              nprows, row_distribution, npcols, col_distribution)
      REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: pbc_scaled_coords
      INTEGER, DIMENSION(:), INTENT(IN)                  :: costs
      INTEGER, INTENT(IN)                                :: nprows
      INTEGER, DIMENSION(:), INTENT(OUT)                 :: row_distribution
      INTEGER, INTENT(IN)                                :: npcols
      INTEGER, DIMENSION(:), INTENT(OUT)                 :: col_distribution

      CHARACTER(len=*), PARAMETER :: routineN = 'make_basic_spatial_distribution'

      INTEGER                                            :: common_div, handle, i, n_bins, n_items
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bin_load, bin_of_item

      CALL timeset(routineN, handle)

      n_items = SIZE(costs)
      common_div = gcd(nprows, npcols)
      n_bins = lcm(nprows, npcols)

      ! bin_load accumulates the cost per bin, bin_of_item receives the bin of every item
      ALLOCATE (bin_load(n_bins), bin_of_item(n_items))
      bin_load(:) = 0

      ! start the recursion at level 0 with the identity index list 1..n_items
      CALL spatial_recurse(pbc_scaled_coords, costs, [(i, i=1, n_items)], bin_load, bin_of_item, 0)

      ! map the bin number of every item onto one-based process rows and columns
      row_distribution(1:n_items) = (bin_of_item(:) - 1)*common_div/npcols + 1
      col_distribution(1:n_items) = (bin_of_item(:) - 1)*common_div/nprows + 1

      DEALLOCATE (bin_load, bin_of_item)

      CALL timestop(handle)

   END SUBROUTINE make_basic_spatial_distribution

! **************************************************************************************************
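!> \brief Recursive helper of make_basic_spatial_distribution: halves the atoms along alternating
!>        coordinate directions until a subset is no larger than the number of bins, then hands the
!>        most expensive atom of that subset to the cheapest bin, the next one to the next bin, etc.
!> \param pbc_scaled_coords scaled coordinates of the atoms in the current subset
!> \param costs costs of the atoms in the current subset
!> \param indices positions of these atoms in the distribution array
!> \param bin_costs accumulated cost per bin, updated whenever an atom is assigned
!> \param distribution bin assigned to each atom, addressed through indices
!> \param level recursion depth; the sorting direction is MOD(level, 3) + 1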
! **************************************************************************************************
   RECURSIVE SUBROUTINE spatial_recurse(pbc_scaled_coords, costs, indices, bin_costs, distribution, level)
      REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: pbc_scaled_coords
      INTEGER, DIMENSION(:), INTENT(IN)                  :: costs, indices
      INTEGER, DIMENSION(:), INTENT(INOUT)               :: bin_costs, distribution
      INTEGER, INTENT(IN)                                :: level

      INTEGER                                            :: axis, n_low, n_slots, n_todo, pick, &
                                                            rank, target_bin
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: bin_order, bin_sorted, cost_order, &
                                                            cost_sorted, order
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: projection

      n_todo = SIZE(costs)
      n_slots = SIZE(bin_costs)

      IF (n_todo > n_slots) THEN
         ! More atoms than bins: sort the atoms along one coordinate direction
         ! (cycling through 1, 2, 3 with the recursion depth), cut the sorted list
         ! in two and treat the lower and the upper part separately.
         axis = MOD(level, 3) + 1
         n_low = (n_todo + 1)/2
         ALLOCATE (projection(n_todo), order(n_todo))
         projection(:) = pbc_scaled_coords(axis, :)
         CALL sort(projection, n_todo, order)
         CALL spatial_recurse(pbc_scaled_coords(:, order(1:n_low)), costs(order(1:n_low)), &
                              indices(order(1:n_low)), bin_costs, distribution, level + 1)
         CALL spatial_recurse(pbc_scaled_coords(:, order(n_low + 1:)), costs(order(n_low + 1:)), &
                              indices(order(n_low + 1:)), bin_costs, distribution, level + 1)
         DEALLOCATE (projection, order)
      ELSE
         ! Few enough atoms: rank the bins by their current cost and the atoms by
         ! their cost, then pair them up in opposite order.
         ALLOCATE (bin_sorted(n_slots), bin_order(n_slots))
         bin_sorted(:) = bin_costs
         CALL sort(bin_sorted, n_slots, bin_order)
         ALLOCATE (cost_sorted(n_todo), cost_order(n_todo))
         cost_sorted(:) = costs
         CALL sort(cost_sorted, n_todo, cost_order)
         DO rank = 1, n_todo
            ! rank-th bin of the sorted bin list gets the rank-th atom counted
            ! from the end of the sorted atom list
            pick = n_todo - rank + 1
            target_bin = bin_order(rank)
            bin_costs(target_bin) = bin_costs(target_bin) + cost_sorted(pick)
            distribution(indices(cost_order(pick))) = target_bin
         END DO
         DEALLOCATE (bin_sorted, bin_order, cost_sorted, cost_order)
      END IF

   END SUBROUTINE spatial_recurse

! **************************************************************************************************
!> \brief creates a distribution placing close by atoms into clusters and
!>        putting them on the same processors. Load balancing is
!>        performed by balancing sum of the cluster costs per processor
!> \param coords coordinates of the system
!> \param scaled_coords scaled coordinates
!> \param cell the cell_type
!> \param costs costs per atomic block
!> \param nprows number of processors per row on the 2d grid
!> \param row_distribution the resulting distribution over proc_rows of atomic blocks
!> \param npcols number of processors per col on the 2d grid
!> \param col_distribution the resulting distribution over proc_cols of atomic blocks
! **************************************************************************************************
   SUBROUTINE make_cluster_distribution(coords, scaled_coords, cell, costs, &
                                        nprows, row_distribution, npcols, col_distribution)
      REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: coords, scaled_coords
      TYPE(cell_type), POINTER                           :: cell
      INTEGER, DIMENSION(:), INTENT(IN)                  :: costs
      INTEGER, INTENT(IN)                                :: nprows
      INTEGER, DIMENSION(:, :), INTENT(OUT)              :: row_distribution
      INTEGER, INTENT(IN)                                :: npcols
      INTEGER, DIMENSION(:, :), INTENT(OUT)              :: col_distribution

      CHARACTER(len=*), PARAMETER :: routineN = 'make_cluster_distribution'

      INTEGER                                            :: handle, i, icluster, natom, output_unit
      INTEGER(KIND=int_8)                                :: ncluster
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: atom_to_cluster, cluster_cost, &
                                                            cluster_count, cluster_to_col, &
                                                            cluster_to_row, piv_cost, proc_cost, &
                                                            sorted_cost
      REAL(KIND=dp)                                      :: fold(3)
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: cluster_center, cluster_high, cluster_low

      CALL timeset(routineN, handle)

      output_unit = cp_logger_get_default_io_unit()

      ! the number of clusters to build is dictated by DBCSR
      natom = SIZE(costs)
      ncluster = dbcsr_distribution_get_num_images(SUM(costs), natom, nprows, npcols)
      ALLOCATE (atom_to_cluster(natom))
      ALLOCATE (cluster_cost(ncluster))
      ALLOCATE (cluster_to_row(ncluster))
      ALLOCATE (cluster_to_col(ncluster))
      ALLOCATE (sorted_cost(ncluster))
      ALLOCATE (piv_cost(ncluster))
      cluster_cost(:) = 0

      ! group the atoms into clusters; fills atom_to_cluster and cluster_cost
      icluster = 0
      CALL cluster_recurse(coords, scaled_coords, cell, costs, atom_to_cluster, ncluster, icluster, cluster_cost)

      ! piv_cost is the index array of the sort of the cluster costs
      sorted_cost(:) = cluster_cost(:)
      CALL sort(sorted_cost, INT(ncluster), piv_cost)

      ! balance the clusters over the process rows ...
      ALLOCATE (proc_cost(nprows))
      proc_cost = 0
      CALL assign_clusters(cluster_cost, piv_cost, proc_cost, cluster_to_row, nprows)

      ! ... and, independently, over the process columns
      DEALLOCATE (proc_cost); ALLOCATE (proc_cost(npcols))
      proc_cost = 0
      CALL assign_clusters(cluster_cost, piv_cost, proc_cost, cluster_to_col, npcols)

      ! column 1 holds the process row/column, column 2 the cluster an atom belongs to
      DO i = 1, natom
         row_distribution(i, 1) = cluster_to_row(atom_to_cluster(i))
         row_distribution(i, 2) = atom_to_cluster(i)
         col_distribution(i, 1) = cluster_to_col(atom_to_cluster(i))
         col_distribution(i, 2) = atom_to_cluster(i)
      END DO

      ! generate some statistics on clusters; they are only ever printed, so the work
      ! is restricted to the process that owns a valid output unit
      IF (output_unit > 0) THEN
         ALLOCATE (cluster_center(3, ncluster))
         ALLOCATE (cluster_low(3, ncluster))
         ALLOCATE (cluster_high(3, ncluster))
         ALLOCATE (cluster_count(ncluster))
         cluster_count = 0
         ! the last atom visited of every cluster serves as reference point for the folding
         DO i = 1, natom
            cluster_count(atom_to_cluster(i)) = cluster_count(atom_to_cluster(i)) + 1
            cluster_center(:, atom_to_cluster(i)) = coords(:, i)
         END DO
         cluster_low = HUGE(0.0_dp)/2
         cluster_high = -HUGE(0.0_dp)/2
         ! bounding box of every cluster, atoms folded next to the reference point
         DO i = 1, natom
            fold = pbc(coords(:, i) - cluster_center(:, atom_to_cluster(i)), cell) + cluster_center(:, atom_to_cluster(i))
            cluster_low(:, atom_to_cluster(i)) = MIN(cluster_low(:, atom_to_cluster(i)), fold(:))
            cluster_high(:, atom_to_cluster(i)) = MAX(cluster_high(:, atom_to_cluster(i)), fold(:))
         END DO

         ! clusters without atoms are masked out of the extent statistics
         WRITE (output_unit, *)
         WRITE (output_unit, '(T2,A)') "Cluster distribution information"
         WRITE (output_unit, '(T2,A,T48,I8)') "Number of atoms", natom
         WRITE (output_unit, '(T2,A,T48,I8)') "Number of clusters", ncluster
         WRITE (output_unit, '(T2,A,T48,I8)') "Largest cluster in atoms", MAXVAL(cluster_count)
         WRITE (output_unit, '(T2,A,T48,I8)') "Smallest cluster in atoms", MINVAL(cluster_count)
         WRITE (output_unit, '(T2,A,T48,F8.3,I8)') "Largest cartesian extend [a.u.]/cluster x=", &
            MAXVAL(cluster_high(1, :) - cluster_low(1, :), MASK=(cluster_count > 0)), &
            MAXLOC(cluster_high(1, :) - cluster_low(1, :), MASK=(cluster_count > 0))
         WRITE (output_unit, '(T2,A,T48,F8.3,I8)') "Largest cartesian extend [a.u.]/cluster y=", &
            MAXVAL(cluster_high(2, :) - cluster_low(2, :), MASK=(cluster_count > 0)), &
            MAXLOC(cluster_high(2, :) - cluster_low(2, :), MASK=(cluster_count > 0))
         WRITE (output_unit, '(T2,A,T48,F8.3,I8)') "Largest cartesian extend [a.u.]/cluster z=", &
            MAXVAL(cluster_high(3, :) - cluster_low(3, :), MASK=(cluster_count > 0)), &
            MAXLOC(cluster_high(3, :) - cluster_low(3, :), MASK=(cluster_count > 0))

         DEALLOCATE (cluster_center, cluster_low, cluster_high, cluster_count)
      END IF

      DEALLOCATE (atom_to_cluster, cluster_cost, cluster_to_row, cluster_to_col, sorted_cost, piv_cost, proc_cost)
      CALL timestop(handle)

   END SUBROUTINE make_cluster_distribution

! **************************************************************************************************
!> \brief assigns the clusters to processors, trying to balance the cost on the nodes
!> \param cluster_cost vector with the cost of each cluster
!> \param piv_cost pivoting vector sorting the cluster_cost
!> \param proc_cost cost per processor, on input 0 everywhere
!> \param cluster_assign assignment of clusters on proc
!> \param nproc number of processor over which clusters are distributed
! **************************************************************************************************
   SUBROUTINE assign_clusters(cluster_cost, piv_cost, proc_cost, cluster_assign, nproc)
      ! Distributes the clusters over nproc processors in rounds of nproc clusters,
      ! going from the end of the sorted cost list towards its beginning. In every
      ! round the processors are ranked by their accumulated cost and the i-th
      ! processor of that ranking receives the i-th cluster counted from the end of
      ! the not yet assigned part of the sorted list.
      !
      ! cluster_cost   : cost of every cluster (read only)
      ! piv_cost       : index array of the sorted cluster costs (read only)
      ! proc_cost      : accumulated cost per processor, updated in place
      ! cluster_assign : processor chosen for every cluster, indexed like cluster_cost
      ! nproc          : number of processors the clusters are spread over
      INTEGER, ALLOCATABLE, DIMENSION(:), INTENT(IN)     :: cluster_cost, piv_cost
      INTEGER, ALLOCATABLE, DIMENSION(:), INTENT(INOUT)  :: proc_cost, cluster_assign
      INTEGER, INTENT(IN)                                :: nproc

      CHARACTER(len=*), PARAMETER                        :: routineN = 'assign_clusters'

      INTEGER                                            :: handle, i, ilevel, ncl, nlevel, nrest, &
                                                            offset, piv_pcost(nproc), &
                                                            sort_proc_cost(nproc)

      CALL timeset(routineN, handle)

      ncl = SIZE(cluster_cost)
      nlevel = ncl/nproc
      nrest = MOD(ncl, nproc)

      ! If the number of clusters is not a multiple of nproc, the full rounds below
      ! only reach the sorted positions 1 .. nlevel*nproc. The nrest clusters at the
      ! end of the sorted list would never get a processor, so hand them out first,
      ! one per processor. Nothing happens here when ncl is a multiple of nproc.
      IF (nrest > 0) THEN
         sort_proc_cost(:) = proc_cost(:)
         CALL sort(sort_proc_cost, nproc, piv_pcost)
         DO i = 1, nrest
            cluster_assign(piv_cost(ncl - i + 1)) = piv_pcost(i)
            proc_cost(piv_pcost(i)) = proc_cost(piv_pcost(i)) + cluster_cost(piv_cost(ncl - i + 1))
         END DO
      END IF

      DO ilevel = 1, nlevel
         ! rank the processors by the cost they carry so far
         sort_proc_cost(:) = proc_cost(:)
         CALL sort(sort_proc_cost, nproc, piv_pcost)

         ! offset - 1 is the last sorted position handled in this round
         offset = (nlevel - ilevel + 1)*nproc + 1
         DO i = 1, nproc
            cluster_assign(piv_cost(offset - i)) = piv_pcost(i)
            proc_cost(piv_pcost(i)) = proc_cost(piv_pcost(i)) + cluster_cost(piv_cost(offset - i))
         END DO
      END DO

      CALL timestop(handle)

   END SUBROUTINE assign_clusters

! **************************************************************************************************
!> \brief recursive routine to cluster atoms.
!>        Low level uses a modified KMEANS algorithm
!>        recursion is used to reduce cost.
!>        each level will subdivide a cluster into smaller clusters
!>        If only a single split is necessary atoms are assigned to the current cluster
!> \param coord coordinates of the system
!> \param scaled_coord scaled coordinates
!> \param cell the cell_type
!> \param costs costs per atomic block
!> \param cluster_inds the atom_to cluster mapping
!> \param ncluster number of clusters still to be created on a given recursion level
!> \param icluster the index of the current cluster to be created
!> \param fin_cluster_cost total cost of the final clusters
! **************************************************************************************************
   RECURSIVE SUBROUTINE cluster_recurse(coord, scaled_coord, cell, costs, cluster_inds, ncluster, icluster, fin_cluster_cost)
      REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: coord, scaled_coord
      TYPE(cell_type), POINTER                           :: cell
      INTEGER, DIMENSION(:), INTENT(IN)                  :: costs
      INTEGER, DIMENSION(:), INTENT(INOUT)               :: cluster_inds
      INTEGER(KIND=int_8), INTENT(INOUT)                 :: ncluster
      INTEGER, INTENT(INOUT)                             :: icluster
      INTEGER, DIMENSION(:), INTENT(INOUT)               :: fin_cluster_cost

      INTEGER                                            :: i, ibeg, iend, maxv(1), min_seed, &
                                                            natoms, nleft, nsplits, seed, tot_cost
      INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:)     :: ncluster_new
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: cluster_cost, inds_tmp, nat_cluster, piv
      LOGICAL                                            :: found
      REAL(KIND=dp)                                      :: balance, balance_new, conv

      ! number of atoms handled on this recursion level
      natoms = SIZE(coord, 2)
      ! The number of splits is a somewhat arbitrary choice: it tries to avoid too many clusters
      ! on large systems and too few (for balancing) on small systems or subclusters.
      ! It is the larger of 6 and INT(60/ln(natoms)), limited by ncluster and by natoms.
      IF (natoms <= 1) THEN
         nsplits = 1
      ELSE
         nsplits = MIN(INT(MIN(INT(MAX(6, INT(60.00/LOG(REAL(natoms, KIND=dp)))), KIND=int_8), ncluster)), natoms)
      END IF
      IF (nsplits == 1) THEN
         ! End of the recursion: all atoms of this subset form the next final cluster
         icluster = icluster + 1
         cluster_inds = icluster
         fin_cluster_cost(icluster) = SUM(costs)
      ELSE
         ALLOCATE (cluster_cost(nsplits), ncluster_new(nsplits), inds_tmp(natoms), piv(natoms), nat_cluster(nsplits))
         ! initialise some values
         cluster_cost = 0; seed = 300; found = .TRUE.; min_seed = seed
         ! first attempt; conv receives the total variance from kmeans and is not used here
         CALL kmeans(nsplits, coord, scaled_coord, cell, cluster_inds, nat_cluster, seed, conv)
         ! balance = ratio of the most to the least populated sub-cluster (in atoms)
         balance = MAXVAL(REAL(nat_cluster, KIND=dp))/MINVAL(REAL(nat_cluster, KIND=dp))

         ! If the system is small enough try to do better in terms of balancing number of atoms per cluster
         ! by changing the seed for the initial guess (at most five further seeds: seed + 40, ..., seed + 200)
         IF (natoms < 1000 .AND. balance > 1.1) THEN
            found = .FALSE.
            DO i = 1, 5
               IF (balance > 1.1) THEN
                  CALL kmeans(nsplits, coord, scaled_coord, cell, cluster_inds, nat_cluster, seed + i*40, conv)
                  balance_new = MAXVAL(REAL(nat_cluster, KIND=dp))/MINVAL(REAL(nat_cluster, KIND=dp))
                  ! remember the seed that gave the best balance so far
                  IF (balance_new < balance) THEN
                     balance = balance_new
                     min_seed = seed + i*40;
                  END IF
               ELSE
                  ! the previous attempt was good enough and its result is still in cluster_inds
                  found = .TRUE.
                  EXIT
               END IF
            END DO
         END IF
         ! If the loop above ran through all its attempts, cluster_inds holds the result of the last
         ! seed tried, which need not be the best one: recompute the assignment with the best seed.
         ! (found also stays .FALSE. when only the very last attempt reached the target balance.)
         IF (.NOT. found) CALL kmeans(nsplits, coord, scaled_coord, cell, cluster_inds, nat_cluster, min_seed, conv)

         ! compute the cost of each cluster to decide how many splits have to be performed on the next lower level
         DO i = 1, natoms
            cluster_cost(cluster_inds(i)) = cluster_cost(cluster_inds(i)) + costs(i)
         END DO
         tot_cost = SUM(cluster_cost)
         ! compute new splitting, can be done more elegant:
         ! every sub-cluster gets a share of ncluster proportional to its cost (integer division, rounds down)
         ncluster_new(:) = ncluster*cluster_cost(:)/tot_cost
         ! nleft = number of final clusters not yet handed out because of the rounding
         nleft = INT(ncluster - SUM(ncluster_new))
         ! As we won't have empty clusters, we can not have 0 as new size, so we correct for this at first
         DO i = 1, nsplits
            IF (ncluster_new(i) == 0) THEN
               ncluster_new(i) = 1
               nleft = nleft - 1
            END IF
         END DO
         ! After this correction the number of clusters may not match anymore, so try to correct it in a
         ! meaningful way without introducing 0 sized blocks again
         IF (nleft /= 0) THEN
            DO i = 1, ABS(nleft)
               IF (nleft < 0) THEN
                  ! too many handed out: take one back from the sub-cluster with the lowest cost per
                  ! final cluster (maxv holds a MINLOC here), unless it only has a single one; in that
                  ! case take it from the sub-cluster with the largest share
                  maxv = MINLOC(cluster_cost/ncluster_new)
                  IF (ncluster_new(maxv(1)) /= 1) THEN
                     ncluster_new(maxv) = ncluster_new(maxv) - 1
                  ELSE
                     maxv = MAXLOC(ncluster_new)
                     ncluster_new(maxv) = ncluster_new(maxv) - 1
                  END IF
               ELSE
                  ! too few handed out: give one more to the sub-cluster with the highest cost per final cluster
                  maxv = MAXLOC(cluster_cost/ncluster_new)
                  ncluster_new(maxv) = ncluster_new(maxv) + 1
               END IF
            END DO
         END IF

         !Now get the permutations to sort the atoms in the nsplits clusters for the next level of iteration
         inds_tmp(:) = cluster_inds(:)
         CALL sort(inds_tmp, natoms, piv)

         ! after sorting, the atoms of sub-cluster i occupy the range ibeg:iend of piv
         ibeg = 1; iend = 0
         DO i = 1, nsplits
            ! a sub-cluster without atoms is skipped, no recursion is done for it
            IF (nat_cluster(i) == 0) CYCLE
            iend = iend + nat_cluster(i)
            CALL cluster_recurse(coord(:, piv(ibeg:iend)), scaled_coord(:, piv(ibeg:iend)), cell, costs(piv(ibeg:iend)), &
                                 inds_tmp(ibeg:iend), ncluster_new(i), icluster, fin_cluster_cost)
            ibeg = ibeg + nat_cluster(i)
         END DO
         ! copy the sorted cluster IDs on the old layout, inds_tmp gets set at the lowest level of recursion
         cluster_inds(piv(:)) = inds_tmp
         DEALLOCATE (cluster_cost, ncluster_new, inds_tmp, piv, nat_cluster)

      END IF

   END SUBROUTINE cluster_recurse

! **************************************************************************************************
!> \brief A modified version of the kmeans algorithm.
!>        The assignment has a penalty function in case clusters become
!>        larger than average. Like this more even sized clusters are created
!>        trading it for locality
!> \param ncent number of centers to be created
!> \param coord coordinates
!> \param scaled_coord scaled coord
!> \param cell the cell_type
!> \param cluster atom to cluster assignment
!> \param nat_cl atoms per cluster
!> \param seed seed for the RNG. Algorithm might need multiple tries to deliver best results
!> \param tot_var the total variance of the clusters around the centers
! **************************************************************************************************
   SUBROUTINE kmeans(ncent, coord, scaled_coord, cell, cluster, nat_cl, seed, tot_var)
      INTEGER                                            :: ncent
      REAL(KIND=dp), DIMENSION(:, :)                     :: coord, scaled_coord
      TYPE(cell_type), POINTER                           :: cell
      INTEGER, DIMENSION(:)                              :: cluster, nat_cl
      INTEGER                                            :: seed
      REAL(KIND=dp)                                      :: tot_var

      CHARACTER(len=*), PARAMETER                        :: routineN = 'kmeans'

      INTEGER                                            :: handle, i, ind, itn, j, nat, oldc
      LOGICAL                                            :: changed
      REAL(KIND=dp) :: average(3, ncent, 2), cent_coord(3, ncent), devi, deviat(ncent), dist, &
         dvec(3), old_var, rn, scaled_cent(3, ncent), var_cl(ncent)
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: dmat
      REAL(KIND=dp), DIMENSION(3, 2)                     :: initial_seed
      TYPE(rng_stream_type)                              :: rng_stream

      CALL timeset(routineN, handle)

      initial_seed = REAL(seed, dp); nat = SIZE(coord, 2)
      ALLOCATE (dmat(ncent, nat))

      rng_stream = rng_stream_type(name="kmeans uniform distribution [0,1]", &
                                   distribution_type=UNIFORM, seed=initial_seed)

! try to find a clever initial guess with centers being somewhat distributed
      rn = rng_stream%next()
      ind = CEILING(rn*nat)
      cent_coord(:, 1) = coord(:, ind)
      DO i = 2, ncent
         DO
            rn = rng_stream%next()
            ind = CEILING(rn*nat)
            cent_coord(:, i) = coord(:, ind)
            devi = HUGE(1.0_dp)
            DO j = 1, i - 1
               dvec = pbc(cent_coord(:, j), cent_coord(:, i), cell)
               dist = SQRT(DOT_PRODUCT(dvec, dvec))
               IF (dist < devi) devi = dist
            END DO
            rn = rng_stream%next()
            IF (rn < devi**2/169.0) EXIT
         END DO
      END DO

! Now start the KMEANS but penalise it in case it starts packing too many atoms into a single set
! Unfortunately, as each assignment depends on what happened before, it cannot be done in parallel
      cluster = 0; old_var = HUGE(1.0_dp)
      DO itn = 1, 1000
         changed = .FALSE.; var_cl = 0.0_dp; tot_var = 0.0_dp; nat_cl = 0; deviat = 0.0_dp
!      !$OMP PARALLEL DO PRIVATE(i,j,dvec)
         DO i = 1, nat
            DO j = 1, ncent
               dvec = pbc(cent_coord(:, j), coord(:, i), cell)
               dmat(j, i) = DOT_PRODUCT(dvec, dvec)
            END DO
         END DO
         DO i = 1, nat
            devi = HUGE(1.0_dp); oldc = cluster(i)
            DO j = 1, ncent
               dist = dmat(j, i) + MAX(nat_cl(j)**2/nat*ncent, nat/ncent)
               IF (dist < devi) THEN
                  devi = dist; cluster(i) = j
               END IF
            END DO
            deviat(cluster(i)) = deviat(cluster(i)) + SQRT(devi)
            nat_cl(cluster(i)) = nat_cl(cluster(i)) + 1
            tot_var = tot_var + devi
            IF (oldc /= cluster(i)) changed = .TRUE.
         END DO
         ! Get the update of the centers done; add a new one in case a center lost all its atoms.
         ! The algorithm would survive, but it's nice to really create what you demand.
         IF (tot_var >= old_var) EXIT
         IF (changed) THEN
            ! Here comes the misery of computing the center of geometry of the clusters in PBC.
            ! Mapping onto the unit circle circumvents all problems.
            average = 0.0_dp
            DO i = 1, SIZE(coord, 2)
               average(:, cluster(i), 1) = average(:, cluster(i), 1) + COS(scaled_coord(:, i)*2.0_dp*pi)
               average(:, cluster(i), 2) = average(:, cluster(i), 2) + SIN(scaled_coord(:, i)*2.0_dp*pi)
            END DO

            DO i = 1, ncent
               IF (nat_cl(i) == 0) THEN
                  rn = rng_stream%next()
                  scaled_cent(:, i) = scaled_coord(:, CEILING(rn*nat))
               ELSE
                  average(:, i, 1) = average(:, i, 1)/REAL(nat_cl(i), dp)
                  average(:, i, 2) = average(:, i, 2)/REAL(nat_cl(i), dp)
                  scaled_cent(:, i) = (ATAN2(-average(:, i, 2), -average(:, i, 1)) + pi)/(2.0_dp*pi)
                  CALL scaled_to_real(cent_coord(:, i), scaled_cent(:, i), cell)
               END IF
            END DO
         ELSE
            EXIT
         END IF
      END DO

      CALL timestop(handle)

   END SUBROUTINE kmeans

END MODULE distribution_methods
